matplotlib

import matplotlib
import matplotlib.pyplot as plt
plt.plot([1, 2, 4, 9, 5, 3])
plt.show()

그렇습니다. 데이터 몇 개로 plot 함수를 호출한 다음, show 함수를 호출해주면 간단히 그래프를 그려볼 수 있습니다!

plot 함수에 단일 배열의 데이터가 주어진다면, 수직 축의 좌표로서 이를 사용하게 되며, 각 데이터의 배열상 색인(인덱스)을 수평 좌표로서 사용합니다. 두 개의 배열을 넣어줄 수도 있습니다: 그러면, 하나는 x 축에 대한것이며, 다른 하나는 y 축에 대한것이 됩니다:

  • 같은 그림을 object oriented API를 이용해 그려보겠습니다.
  • object oriented API는 그래프의 각 부분을 객체로 지정하고 그리는 것으로, 다음과 같은 패턴을 가지고 있습니다. 아래 코드와 주석의 # object oriented API 부분은 이제현이 추가한 부분입니다.
  • object oriented API와 구분하기 위해 원본 코드에는 #pyplot이라는 헤더를 달았습니다.)
fig, ax = plt.subplots()
ax.plot([1,2,3,4,5])
fig # 도화지
# fig = 종이, ax = 그림 
fig, ax = plt.subplots()

# 2. ax 위에 그래프를 그립니다.
ax.plot([1, 2, 4, 9, 5, 3])

# 3. 그래프를 화면에 출력합니다.
plt.show()
  • pyplot과 동일한 형태의 그래프가 그려집니다.
  • fig, ax를 선언하느라 한 줄을 더 입력해야 한다는 불편함이 있지만 ax 객체가 있어 그래프를 제어하기 더 쉬워집니다.
  • 많은 경우 fig, ax = plt.subplots() 대신 ax = plt.subplot()으로 해도 됩니다.
  • 그러나 fig 대상 명령(예. savefig)을 사용해야 할 때도 있고, 두 가지를 따로 외우려면 혼동이 되니 한 가지로 통일하는 것이 좋습니다.
plt.plot([-3, -2, 5, 0], [1, 6, 4, 3]) # 앞쪽괄호가 x축, 뒷쪽이 y축
plt.show()
fig, ax = plt.subplots()
ax.plot([-3, -2, 5, 0], [1, 6, 4, 3])# 앞쪽괄호가 x축, 뒷쪽이 y축
plt.show()

이번에는 수학적인 함수를 그려보겠습니다. NumPy의 linespace 함수를 사용하여 -2 ~ 2 범위에 속하는 500개의 부동소수로 구성된 x 배열을 생성합니다. 그 다음 x의 각 값의 거듭제곱된 값을 포함하는 y 배열을 생성합니다 (NumPy에 대하여 좀 더 알고 싶다면, NumPy 튜토리얼을 참고하시기 바랍니다).

import numpy as np
x = np.linspace(-2, 2, 500)  # -2부터 2(포함)까지 500개로 나눈 것(포인트를 많이 나열해서 선같아보이는 것)
y = x**2

plt.plot(x, y)
plt.show()
fig, ax = plt.subplots()

ax.plot(x, y)

plt.show()
plt.plot(x, y)
plt.title("Square function") # 상단 제목
plt.xlabel("x") # x축 제목
plt.ylabel("y = x**2") # y축 제목
plt.grid(True) # 격자무늬
plt.show()
  • object-oriented API는 축 이름과 같은 설정 명령어가 pyplot과 다소 다릅니다.
  • 대체로 축 이름(label), 범위(limits) 등을 지정하는 명령어는 set_대상(), 거꾸로 그래프에서 설정값을 가져오는 명령어는 get_대상()으로 통일되어 있습니다.
  • 개인적으로 pyplot의 명령어 체계보다 object-oriented API의 체계를 선호합니다.
fig, ax = plt.subplots()

# 위와 비슷하지만 plt.이 아닌 ax.set_을 이용해 그래프를 커스텀할 수 있다
ax.plot(x, y)
ax.set_title("Square function")
ax.set_xlabel("x")
ax.set_ylabel("y = x**2")
ax.grid(True)

plt.show()

선의 스타일과 색상

기본적으로 matplotlib은 바로 다음에 위치한(연이은) 데이터 사이에 선을 그립니다

# 위에서와 마찬가지로 앞쪽이 x축, 뒷쪽이 y축
plt.plot([0, 100, 100, 0, 0, 100, 50, 0, 100], [0, 0, 100, 100, 0, 100, 130, 100, 0])
plt.axis([-10, 110, -10, 140]) # x,y축의 범위를 지정
plt.show()
fig, ax = plt.subplots()

# 위와 동일, 앞쪽이 x축, 뒷쪽이 y축
ax.plot([0, 100, 100, 0, 0, 100, 50, 0, 100], [0, 0, 100, 100, 0, 100, 130, 100, 0])
ax.set_xlim(-10, 110) # 최소 최대 지정
ax.set_ylim(-10, 140) # 최소 최대 지정

# 그래프의 범위는 pyplot과 같이 ax.axis([-10, 110, -10, 140]) 으로 지정할 수 있습니다.
# 하지만 위와 같이 set_xlim, set_ylim을 사용해서 명시하는 것이 더 체계적으로 느껴집니다.

plt.show()

세 번째 파라미터를 지정하면 선의 스타일과 색상을 바꿀 수 있습니다. 예를 들어서 "g--"는 "초록색 파선"을 의미합니다.
예를 들어 아래와 같이 말이죠:

# "r-" 앞에 있는 좌표로 빨간 실선을 그리고
# "g--"앞에 있는 좌표로 초록 점선을 그려준다
plt.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-", [0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140]) # x,y축 설정
plt.show()
fig, ax = plt.subplots()

# "r-" 앞에 있는 좌표로 빨간 실선을 그리고
# "g--"앞에 있는 좌표로 초록 점선을 그려준다
ax.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-", [0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
ax.set_xlim(-10, 110) # x축 범위 설정
ax.set_ylim(-10, 140) # y축 범위 설정

plt.show()

또는 show를 호출하기 전 plot을 여러번 호출해도 가능합니다.

# 여러 개의 그림을 한번에 나타내기

plt.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-")
plt.plot([0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
plt.axis([-10, 110, -10, 140])
plt.show()
fig, ax = plt.subplots()

ax.plot([0, 100, 100, 0, 0], [0, 0, 100, 100, 0], "r-")
ax.plot([0, 100, 50, 0, 100], [0, 100, 130, 100, 0], "g--")
ax.set_xlim(-10, 110)
ax.set_ylim(-10, 140)

plt.show()

선 대신에 간단한 점을 그려보는 것도 가능합니다. 아래는 초록색 파선, 빨강 점선, 파랑 삼각형의 예를 보여줍니다. 공식 문서에서 사용 가능한 스타일 및 색상의 모든 옵션을 확인해 볼 수 있습니다.

x = np.linspace(-1.4, 1.4, 30) # -1.4,1.4를 포함하며 그 사이의 값을 30개로 나눠서그림
plt.plot(x, x, 'g--', x, x**2, 'r:', x, x**3, 'b^') 
# g--앞의 것을 초록실선으로 , r:앞에 있는 것을 빨간 점이 이어지 모양으로, b^앞에 있는 것을 파란 삼각형이 이어지게 그려준다
plt.show()
fig, ax = plt.subplots()

x = np.linspace(-1.4, 1.4, 30) # -1.4부터 1.4(포함)를 30개로 나눠라

ax.plot(x, x, 'g--') # x값을 이용하여 
ax.plot(x, x**2, 'r:')
ax.plot(x, x**3, 'b^')

# 여러 그래프를 ax.plot(x, x, 'g--', x, x**2, 'r:', x, x**3, 'b^')과 같이 한 줄에 그릴 수도 있습니다.
# 그러나 이와 같이 따로 떼서 그리면 혼동을 방지할 수 있습니다.
# 이는 pyplot도 마찬가지입니다.

plt.show()

plot 함수는 Line2D객체로 구성된 리스트를 반환합니다 (각 객체가 각 선에 대응됩니다). 이 선들에 대한 추가적인 속성을 설정할 수도 있습니다. 가령 선의 두께, 스타일, 투명도 같은것의 설정이 가능합니다. 공식 문서에서 설정 가능한 모든 속성을 확인해볼 수 있습니다.

x = np.linspace(-1.4, 1.4, 30)
line1, line2, line3 = plt.plot(x, x, 'g--', x, x**2, 'r:', x, x**3, 'b^')
line1.set_linewidth(3.0) # 더 두껍게 해준듯
line1.set_dash_capstyle("round") # 대쉬 스타일을 바꿔서 둥글둥글한 모양으로 바꿈
line3.set_alpha(0.2) # 투명도
plt.show()
x = np.linspace(-1.4, 1.4, 30)

fig, ax = plt.subplots()

# plot을 나누어 그리면 어디에 어떤 설정이 적용되었는지 알아보기 편합니다.
# linewidth, alpha와 같은 line style도 plot() 안에 넣으면 혼동을 방지할 수 있습니다.
line1 = ax.plot(x, x, 'g--', linewidth=3, dash_capstyle='round') # 위에보다 선이 굵어지고 모양이 더 동글동글 해짐 대시 스타일을 round로 바꿨기 때문
line2 = ax.plot(x, x**2, 'r:')
line3 = ax.plot(x, x**3, 'b^', alpha=0.2) # 투명도 추가

plt.show()

그림 저장

그래프를 그림파일로 저장하는 방법은 간단합니다. 단순히 파일이름을 지정하여 savefig 함수를 호출해 주기만 하면 됩니다. 가능한 이미지 포맷은 사용하는 그래픽 백엔드에 따라서 지원 여부가 결정됩니다.

x = np.linspace(-1.4, 1.4, 30)
plt.plot(x, x**2)
plt.savefig("my_square_function.png", transparent=True) # 저장, transparent = True는 배경을 투명하게 한다는 뜻

부분 그래프 (subplot)

matplotlib는 하나의 그림(figure)에 여러개의 부분 그래프를 포함할 수 있습니다. 이 부분 그래프는 격자 형식으로 관리됩니다. subplot 함수를 호출하여 부분 그래프를 생성할 수 있습니다. 이 때 격자의 행/열의 수 및 그래프를 그리고자 하는 부분 그래프의 색인을 파라미터로서 지정해줄 수 있습니다 (색인은 1부터 시작하며, 좌->우, 상단->하단의 방향입니다).

  • pyplot은 현재 활성화된 부분 그래프를 계속해서 추적합니다 (plt.gca()를 호출하여 해당 부분 그래프의 참조를 얻을 수 있습니다). 따라서, plot 함수를 호출할 때 활성화된 부분 그래프에 그림이 그려지게 됩니다.
  • object oriented API 방식에서는 그래프를 그리기 전에 먼저 틀을 잡아둡니다. 그래프를 그릴 때 사전에 정의된 영역 중 어디에 그래프를 그릴지 지정하는 방식입니다.
  • pyplotplt.gca()가 바로 object oriented API의 axes입니다.
x = np.linspace(-1.4, 1.4, 30)

# subplot(2,2,1)은 subplot(221)로 축약할 수 있습니다.
plt.subplot(2, 2, 1)  # 2 행 2 열 크기의 격자 중 첫 번째 부분 그래프 = 좌측 상단
plt.plot(x, x)
plt.subplot(2, 2, 2)  # 2 행 2 열 크기의 격자 중 두 번째 부분 그래프 = 우측 상단
plt.plot(x, x**2)
plt.subplot(2, 2, 3)  # 2 행 2 열 크기의 격자 중 세 번째 부분 그래프 = 좌측 하단
plt.plot(x, x**3)
plt.subplot(2, 2, 4)  # 2 행 2 열 크기의 격자 중 네 번째 부분 그래프 = 우측 하단
plt.plot(x, x**4)
plt.show()
x = np.linspace(-1.4, 1.4, 30)

fig, ax = plt.subplots(2, 2) # 순서대로 row의 갯수, col의 갯수입니다. nrows=2, cols=2로 지정할 수도 있습니다.

# plot위치는 ax[row, col] 또는 ax[row][col]로 지정합니다.
ax[0, 0].plot(x, x)      # 2 행 2 열 크기의 격자 중 첫 번째 부분 그래프 = 좌측 상단
ax[0, 1].plot(x, x**2)   # 2 행 2 열 크기의 격자 중 두 번째 부분 그래프 = 우측 상단
ax[1, 0].plot(x, x**3)   # 2 행 2 열 크기의 격자 중 세 번째 부분 그래프 = 좌측 하단
ax[1, 1].plot(x, x**4)   # 2 행 2 열 크기의 격자 중 네 번째 부분 그래프 = 우측 하단

plt.show()

격자의 여러 영역으로 확장된 부분 그래프를 생성하는 것도 쉽습니다:

# 밑에 2개 열을 차지하는 부분 주의하기 line8에 쓰여져 있는 주석 읽고 잘 이해하기

plt.subplot(2, 2, 1)  # 2 행 2 열 크기의 격자 중 첫 번째 부분 그래프 = 좌측 상단
plt.plot(x, x)
plt.subplot(2, 2, 2)  # 2 행 2 열 크기의 격자 중 두 번째 부분 그래프 = 우측 상단
plt.plot(x, x**2)
plt.subplot(2, 1, 2)  # 2행 *1* 열의 두 번째 부분 그래프 = 하단
                      # 2행 1열 크기의 그래프가 두 개 그려질 수 있지만,
                      # 상단 부분은 이미 두 개의 부분 그래프가 차지하였다.
                      # 따라서, 두 번째 부분 그래프로 지정함
plt.plot(x, x**3)
plt.show()
grid = plt.GridSpec(2, 2)  # 2행 2열 크기의 격?자를 준비합니다.

ax1 = plt.subplot(grid[0, 0])  # 2행 2열 크기의 격자 중 첫 번째 부분 그래프 = 좌측 상단
ax2 = plt.subplot(grid[0, 1])  # 2행 2열 크기의 격자 중 두 번째 부분 그래프 = 우측 상단
ax3 = plt.subplot(grid[1, 0:]) # 2행 *1*열의 두 번째 부분 그래프 = 하단
                               # 범위를 [1, 0:]으로 설정하여 2행 전체를 지정함.

ax1.plot(x, x)
ax2.plot(x, x**2)
ax3.plot(x, x**3)

plt.show()

보다 복잡한 부분 그래프의 위치 선정이 필요하다면, subplot2grid를 대신 사용할 수 있습니다. 격자의 행과 열의 번호 및 격자에서 해당 부분 그래프를 그릴 위치를 지정해줄 수 있습니다 (좌측상단 = (0,0). 또한 몇 개의 행/열로 확장되어야 하는지도 추가적으로 지정할 수 있습니다. 아래는 그에 대한 예를 보여줍니다:

# 그림 그리면서 보여주신 것
# rowspan = 2라면 행방향으로 2칸, colspan=2라면 열 방향으로 2칸

plt.subplot2grid((3,3), (0, 0), rowspan=2, colspan=2) # 3x3 매트릭스에서 0행0열에 해당하는 자리에서 rowspan과 colspan이 모두 2로 0행0열을 기준으로 2x2매트릭스만큼 자치하겠다는 뜻
plt.plot(x, x**2)
plt.subplot2grid((3,3), (0, 2)) # 3x3 매트릭스에서 0행2열에 해당하는 자리에 그림(rowspan이나 colspan을 따로 지정하지 않을 때는 해당 칸만 씀)
plt.plot(x, x**3)
plt.subplot2grid((3,3), (1, 2), rowspan=2) # 3x3 매트릭스에서 1행2열에 해당하는 자리에서 rowspan = 2니까 1행2열과 2행2열에 해당 그래프를 그림
plt.plot(x, x**4)
plt.subplot2grid((3,3), (2, 0), colspan=2) # 3x3 매트릭스에서 2행0열에 해당하는 자리에서 colspan = 2니까 2행 0열과 2행 1열을 자치하게 그래프 그림 
plt.plot(x, x**5)
plt.show()
gridsize = (3, 3)     # 2행 2열 크기의 격자를 준비합니다.
ax1 = plt.subplot2grid(gridsize, (0,0), rowspan=2, colspan=2)
ax2 = plt.subplot2grid(gridsize, (0,2))
ax3 = plt.subplot2grid(gridsize, (1,2), rowspan=2)
ax4 = plt.subplot2grid(gridsize, (2,0), colspan=2)

ax1.plot(x, x**2)
ax2.plot(x, x**3)
ax3.plot(x, x**4)
ax4.plot(x, x**5)

plt.show()

보다 유연한 부분그래프 위치선정이 필요하다면, GridSpec 문서를 확인해 보시길 바랍니다.

여러개의 그림 (figure)

여러개의 그림을 그리는것도 가능합니다. 각 그림은 하나 이상의 부분 그래프를 가질 수 있습니다. 기본적으로는 matplotlib이 자동으로 figure(1)을 생성합니다. 그림간 전환을 할 때, pyplot은 현재 활성화된 그림을 계속해서 추적합니다 (이에대한 참조는 plt.gcf()의 호출로 알 수 있습니다). 또한 활성화된 그림의 활성화된 부분 그래프가 현재 그래프가 그려질 부분 그래프가 됩니다.

  • object oriented API에서는 실행 순이 아니라 객체를 중심으로 명령을 실행합니다.
  • 다른 그림을 그리다가 앞서 그림을 추가할 때 pyplot에서 plt.figure() 명령으로 위 그림을 호출하는 대신 object oriented API는 목표 Axes를 지정하여 추가합니다.
x = np.linspace(-1.4, 1.4, 30)

plt.figure(1)
plt.subplot(211) #plt.subplot(211)는 plt.subplot(2,1,1)와 같다
plt.plot(x, x**2)
plt.title("Square and Cube")
plt.subplot(212)
plt.plot(x, x**3)

plt.figure(2, figsize=(10, 5))
plt.subplot(121)
plt.plot(x, x**4)
plt.title("y = x**4")
plt.subplot(122)
plt.plot(x, x**5)
plt.title("y = x**5")

plt.figure(1)      # 그림 1로 돌아가며, 활성화된 부분 그래프는 212 (하단)이 됩니다*********
plt.plot(x, -x**3, "r:")

plt.show()
x = np.linspace(-1.4, 1.4, 30)

fig1, ax1 = plt.subplots(nrows=2, ncols=1)

ax1[0].plot(x, x**2)
ax1[0].set_title("Square and Cube")

ax1[1].plot(x, x**3)


fig2, ax2 = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax2[0].plot(x, x**4)
ax2[0].set_title("y = x**4")

ax2[1].plot(x, x**5)
ax2[1].set_title("y = x**5")

ax1[1].plot(x, -x**3, "r:")    # 그림 1로 돌아가며, 활성화된 부분 그래프는 ax1[1] (하단)이 됩니다.
plt.show()
  • 여기서 설명하는 부분은 matplotlibobject oriented API(객체지향 인터페이스)입니다.
x = np.linspace(-2, 2, 200)
fig1, (ax_top, ax_bottom) = plt.subplots(2, 1, sharex=True)
fig1.set_size_inches(10,5)
line1, line2 = ax_top.plot(x, np.sin(3*x**2), "r-", x, np.cos(5*x**2), "b-") # line1은 파란색 선, line2는 빨간색 여기까지 그리면 맨 위에 그림 만들어짐

line3, = ax_bottom.plot(x, np.sin(3*x), "r-") # 두 번째 그래프 (빨간색)
ax_top.grid(True) 

fig2, ax = plt.subplots(1, 1) # 맨 마지막 그래프 (하늘색)
ax.plot(x, x**2)
plt.show()

일관성을 위해서 이 튜토리얼의 나머지 부분에서는 pyplot의 상태 머신을 계속해서 사용할 것입니다. 그러나 프로그램에서는 객체지향 인터페이스의 사용을 권장하고 싶습니다.

Pylab vs Pyplot vs Matplotlib

pylab, pyplot, matplotlib 간의 관계에대한 혼동이 있습니다. 그러나 이들의 관계는 매우 단순합니다: matplotlib은 완전한 라이브러리이며, pylab 및 pyplot을 포함한 모든것을 가지고 있습니다.

Pyplot은 그래프를 그리기위한 다양한 도구를 제공합니다. 여기에는 내부적인 객체지향적인 그래프 그리기 라이브러리에 대한 상태 머신 인터페이스도 포함됩니다.

Pylab은 mkatplotlib.pyplot 및 NumPy를 단일 네임스페이스로 임포트하는 편리성을 위한 모듈입니다. 인터넷에 떠도는 pylab을 사용하는 여러 예제를 보게 될 것입니다. 그러나 이는 더이상 권장되는 사용방법은 아닙니다 (왜냐하면 명시적인 임포트가 암시적인것 보다 더 낫기 때문입니다).

  • Pylab, Pyplot, Object oriented API의 관계는 여기를 참고하십시오

텍스트 그리기

text 함수를 호출하여 텍스트를 그래프의 원하는 위치에 추가할 수 있습니다. 출력을 원하는 텍스트와 수평 및 수직 좌표를 지정하고, 추가적으로 몇 가지 속성을 지정해 주기만 하면 됩니다. matplotlib의 모든 텍스트는 TeX 방정식 표현을 포함할 수 있습니다. 더 자세한 내용은 공식 문서를 참조하시기 바랍니다.

x = np.linspace(-1.5, 1.5, 30)
px = 0.8
py = px**2

plt.plot(x, x**2, "b-", px, py, "ro") # 여기서 점을 찍음, 그래프 선은 파란색으로 하고 좌표px,py에 (빨간)점을 찍는 코드

plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='blue', horizontalalignment="center")
plt.text(px - 0.08, py, "Beautiful point", ha="right", weight="heavy")
plt.text(px, py, "x = %0.2f\ny = %0.2f"%(px, py), rotation=50, color='gray')

plt.show()
fig, ax = plt.subplots()

ax.plot(x, x**2, "b-") # x축은 x, y축은 x**2 따라서 y=x**2를 그린다
ax.plot(px, py, "ro") # 점 찍힘

ax.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='blue', horizontalalignment="center")
ax.text(px - 0.08, py, "Beautiful point", ha="right", weight="heavy")
ax.text(px, py, "x = %0.2f\ny = %0.2f"%(px, py), rotation=50, color='gray')

plt.show()
  • 알아둘 것: hahorizontalalignment(수평정렬)의 이명 입니다.

더 많은 텍스트 속성을 알고 싶다면, 공식 문서를 참조하시기 바랍니다.

아래 그래프의 "beautiful point" 같은 텍스트 처럼, 그래프의 요소에 주석을 다는것은 꽤 흔한 일입니다. annotate 함수는 이를 쉽게 할 수 있게 해 줍니다: 관심있는 부분의 위치를 지정하고, 텍스트의 위치를 지정합니다. 그리고 텍스트 및 화살표에 대한 추가적인 속성도 지정해줄 수 있습니다.

plt.plot(x, x**2, px, py, "ro")
plt.annotate("Beautiful point", xy=(px, py), xytext=(px-1.3,py+0.5),
                           color="green", weight="heavy", fontsize=14,
                           arrowprops={"facecolor": "lightgreen"})
plt.show()
fig, ax = plt.subplots()
ax.plot(x, x**2, px, py, "ro")
ax.annotate("Beautiful point", xy=(px, py), xytext=(px-1.3,py+0.5),
                           color="green", weight="heavy", fontsize=14,
                           arrowprops={"facecolor": "lightgreen"})
plt.show()

bbox 속성을 사용하면, 텍스트를 포함하는 사각형을 그려볼 수도 있습니다:

# 빨간점
plt.plot(x, x**2, px, py, "ro")

# 화살표 부분
bbox_props = dict(boxstyle="rarrow,pad=0.3", ec="b", lw=2, fc="lightblue") 
plt.text(px-0.2, py, "Beautiful point", bbox=bbox_props, ha="right")

# 네모 박스 부분
bbox_props = dict(boxstyle="round4,pad=1,rounding_size=0.2", ec="black", fc="#EEEEFF", lw=5) 
plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='black', ha="center", bbox=bbox_props)

plt.show()
fig, ax = plt.subplots()
ax.plot(x, x**2)
ax.plot(px, py, "ro")

bbox_props = dict(boxstyle="rarrow,pad=0.3", ec="b", lw=2, fc="lightblue") # 화살표 부분
ax.text(px-0.2, py, "Beautiful point", bbox=bbox_props, ha="right")

bbox_props = dict(boxstyle="round4,pad=1,rounding_size=0.2", ec="black", fc="#EEEEFF", lw=5) # 네모 박스 부분
ax.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='black', ha="center", bbox=bbox_props)

plt.show()

재미를 위해서 xkcd 스타일의 그래프를 그려보고 싶다면, with plt.xkcd() 섹션 블록을 활용할 수도 있습니다:

with plt.xkcd():
    plt.plot(x, x**2, px, py, "ro")

    bbox_props = dict(boxstyle="rarrow,pad=0.3", ec="b", lw=2, fc="lightblue")
    plt.text(px-0.2, py, "Beautiful point", bbox=bbox_props, ha="right")

    bbox_props = dict(boxstyle="round4,pad=1,rounding_size=0.2", ec="black", fc="#EEEEFF", lw=5)
    plt.text(0, 1.5, "Square function\n$y = x^2$", fontsize=20, color='black', ha="center", bbox=bbox_props)

    plt.show()

범례 (Legends)

범례를 추가하는 가장 간단한 방법은 모든 선에 라벨을 설정 해 주고, legend 함수를 호출하는 것입니다.

x = np.linspace(-1.4, 1.4, 50)
plt.plot(x, x**2, "r--", label="Square function")
plt.plot(x, x**3, "g-", label="Cube function")
plt.legend(loc="best")
plt.grid(True)
plt.show()
x = np.linspace(-1.4, 1.4, 50)

fig, ax = plt.subplots()

ax.plot(x, x**2, "r--", label="Square function")
ax.plot(x, x**3, "g-", label="Cube function")
ax.legend(loc="best")
ax.grid(True)
plt.show()

비선형 척도

Matplotlib은 로그, 로짓(logit)과 같은 비선형 척도를 지원합니다.

# 로그그래프 그릴줄 알아야함
x = np.linspace(0.1, 15, 500)
y = x**3/np.exp(2*x)

plt.figure(1)
plt.plot(x, y)
plt.yscale('linear')
plt.title('linear')
plt.grid(True)

plt.figure(2)
plt.plot(x, y)
plt.yscale('log')
plt.title('log')
plt.grid(True)

plt.figure(3)
plt.plot(x, y)
plt.yscale('logit')
plt.title('logit')
plt.grid(True)

plt.figure(4)
plt.plot(x, y - y.mean())
plt.yscale('symlog', linthreshy=0.05)
plt.title('symlog')
plt.grid(True)

plt.show()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
c:\Users\user\Desktop\강수인\2022-05-13-matplotlib-seaborn-kangsuin.ipynb Cell 65' in <cell line: 27>()
     <a href='vscode-notebook-cell:/c%3A/Users/user/Desktop/%EA%B0%95%EC%88%98%EC%9D%B8/2022-05-13-matplotlib-seaborn-kangsuin.ipynb#ch0000117?line=24'>25</a> plt.figure(4)
     <a href='vscode-notebook-cell:/c%3A/Users/user/Desktop/%EA%B0%95%EC%88%98%EC%9D%B8/2022-05-13-matplotlib-seaborn-kangsuin.ipynb#ch0000117?line=25'>26</a> plt.plot(x, y - y.mean())
---> <a href='vscode-notebook-cell:/c%3A/Users/user/Desktop/%EA%B0%95%EC%88%98%EC%9D%B8/2022-05-13-matplotlib-seaborn-kangsuin.ipynb#ch0000117?line=26'>27</a> plt.yscale('symlog', linthreshy=0.05)
     <a href='vscode-notebook-cell:/c%3A/Users/user/Desktop/%EA%B0%95%EC%88%98%EC%9D%B8/2022-05-13-matplotlib-seaborn-kangsuin.ipynb#ch0000117?line=27'>28</a> plt.title('symlog')
     <a href='vscode-notebook-cell:/c%3A/Users/user/Desktop/%EA%B0%95%EC%88%98%EC%9D%B8/2022-05-13-matplotlib-seaborn-kangsuin.ipynb#ch0000117?line=28'>29</a> plt.grid(True)

File c:\ProgramData\Anaconda3\lib\site-packages\matplotlib\pyplot.py:3055, in yscale(value, **kwargs)
   <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/pyplot.py?line=3052'>3053</a> @_copy_docstring_and_deprecators(Axes.set_yscale)
   <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/pyplot.py?line=3053'>3054</a> def yscale(value, **kwargs):
-> <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/pyplot.py?line=3054'>3055</a>     return gca().set_yscale(value, **kwargs)

File c:\ProgramData\Anaconda3\lib\site-packages\matplotlib\axes\_base.py:4108, in _AxesBase.set_yscale(self, value, **kwargs)
   <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axes/_base.py?line=4105'>4106</a> g = self.get_shared_y_axes()
   <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axes/_base.py?line=4106'>4107</a> for ax in g.get_siblings(self):
-> <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axes/_base.py?line=4107'>4108</a>     ax.yaxis._set_scale(value, **kwargs)
   <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axes/_base.py?line=4108'>4109</a>     ax._update_transScale()
   <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axes/_base.py?line=4109'>4110</a>     ax.stale = True

File c:\ProgramData\Anaconda3\lib\site-packages\matplotlib\axis.py:761, in Axis._set_scale(self, value, **kwargs)
    <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axis.py?line=758'>759</a> def _set_scale(self, value, **kwargs):
    <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axis.py?line=759'>760</a>     if not isinstance(value, mscale.ScaleBase):
--> <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axis.py?line=760'>761</a>         self._scale = mscale.scale_factory(value, self, **kwargs)
    <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axis.py?line=761'>762</a>     else:
    <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/axis.py?line=762'>763</a>         self._scale = value

File c:\ProgramData\Anaconda3\lib\site-packages\matplotlib\scale.py:597, in scale_factory(scale, axis, **kwargs)
    <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/scale.py?line=594'>595</a>     scale = scale.lower()
    <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/scale.py?line=595'>596</a> scale_cls = _api.check_getitem(_scale_mapping, scale=scale)
--> <a href='file:///c%3A/ProgramData/Anaconda3/lib/site-packages/matplotlib/scale.py?line=596'>597</a> return scale_cls(axis, **kwargs)

TypeError: __init__() got an unexpected keyword argument 'linthreshy'

틱과 틱커 (Ticks and tickers)

각 축에는 "틱(ticks)"이라는 작은 표시가 있습니다. 정확히 말하자면, "틱"은 표시(예. (-1, 0, 1))의 위치"이며, 틱 선은 그 위치에 그려지는 작은 선입니다. 또한 "틱 라벨"은 틱 선 옆에 그려지는 라벨이며, "틱커"는 틱의 위치를 결정하는 객체 입니다. 기본적인 틱커는 ~5 에서 8 틱을 위치시키는데 꽤 잘 작동합니다. 즉, 틱 서로간에 적당한 거리를 표현합니다.

하지만, 가끔은 좀 더 이를 제어할 필요가 있습니다 (예. 위의 로짓 그래프에서는 너무 많은 틱 라벨이 있습니다). 다행히도 matplotlib은 틱을 완전히 제어하는 방법을 제공합니다. 심지어 보조 눈금(minor tick)을 활성화 할 수도 있습니다.

# 이제현 주: 사실상 object oriented API 입니다.

x = np.linspace(-2, 2, 100)

plt.figure(1, figsize=(15,10))
plt.subplot(131)
plt.plot(x, x**3)
plt.grid(True)
plt.title("Default ticks")

ax = plt.subplot(132)
plt.plot(x, x**3)
ax.xaxis.set_ticks(np.arange(-2, 2, 1))
plt.grid(True)
plt.title("Manual ticks on the x-axis")

ax = plt.subplot(133)
plt.plot(x, x**3)
plt.minorticks_on()
ax.tick_params(axis='x', which='minor', bottom='off')
ax.xaxis.set_ticks([-2, 0, 1, 2])
ax.yaxis.set_ticks(np.arange(-5, 5, 1))
ax.yaxis.set_ticklabels(["min", -4, -3, -2, -1, 0, 1, 2, 3, "max"])
plt.title("Manual ticks and tick labels\n(plus minor ticks) on the y-axis")


plt.grid(True)

plt.show()
# 위 pyplot 예제는 사실상 object oriented API 입니다.
# 여기에서는 같은 기능을 더 단순한 코드로 구현하였습니다

x = np.linspace(-2, 2, 100)

fig, ax = plt.subplots(ncols=3, figsize=(15, 10))

ax[0].plot(x, x**3)
ax[0].grid(True)
ax[0].set_title("Default ticks")

ax[1].plot(x, x**3)
ax[1].grid(True)
ax[1].set_xticks(np.arange(-2, 2, 1))
ax[1].set_title("Manual ticks on the x-axis")

ax[2].plot(x, x**3)
ax[2].grid(True)
ax[2].minorticks_on()
ax[2].set_xticks([-2, 0, 1, 2], minor=False)
ax[2].set_yticks(np.arange(-5, 5, 1))
ax[2].set_yticklabels(["min", -4, -3, -2, -1, 0, 1, 2, 3, "max"])
ax[2].set_title("Manual ticks and tick labels\n(plus minor ticks) on the y-axis")

plt.show()

극좌표계의 투영 (Polar projection)

극좌표계 그래프를 그리는 것은 매우 간단합니다. 부분 그래프를 생성할 때 projection 속성을 "polar"로 설정해 주기만 하면 됩니다.

  • object oriented API는 일반적으로 plt.subplots()FigureAxes 객체를 동시에 생성합니다.
  • plt.subplots()projection 속성을 가지고 있지 않습니다.
  • 따라서 projection을 사용할 때는 plt.figure()Figure 객체를 먼저 생성한 후 plt.subplot()이나 plt.add_subplot()으로 Axes 객체를 추가해 주거나, fig.subplots() 안에 subplot_kw=={'polar':True}로 지정해 주어야 합니다.
radius = 1
theta = np.linspace(0, 2*np.pi*radius, 1000)

plt.subplot(111, projection='polar')
plt.plot(theta, np.sin(5*theta), "g-")
plt.plot(theta, 0.5*np.cos(20*theta), "b-")
plt.show()
radius = 1
theta = np.linspace(0, 2*np.pi*radius, 1000)

fig = plt.figure()
ax = fig.add_subplot(projection='polar')

# 또는, subplot_kw 를 이용해서 polar plot으로 설정합니다.
# fig, ax = plt.subplots(subplot_kw={'polar':True}) 

ax.plot(theta, np.sin(5*theta), "g-")
ax.plot(theta, 0.5*np.cos(20*theta), "b-")
plt.show()

3차원 투영

3차원 그래프를 그리는것은 꽤 간단합니다. 우선 "3d" 투영을 등록하는 Axes3D를 임포트 해줘야 합니다. 그리곤 projection 속성을 "3d"로 설정된 부분 그래프 생성합니다. 그러면 Axes3DSubplot 이라는 객체가 반환되는데, 이 객체의 plot_surface 메서드를 호출하면 x, y, z 좌표를 포함한 추가적이나 속성을 지정할 수 있습니다.

 
# 사실상 object oriented API 입니다.

from mpl_toolkits.mplot3d import Axes3D

x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)

figure = plt.figure(1, figsize = (12, 4))
subplot3d = plt.subplot(111, projection='3d')  # Axes 객체입니다. 
surface = subplot3d.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=matplotlib.cm.coolwarm, linewidth=0.1)
plt.show()

동일한 데이터를 출력하는 또 다른 방법은 등고선도(contour plot)를 이용하는 것입니다.

plt.contourf(X, Y, Z, cmap=matplotlib.cm.coolwarm)
plt.colorbar()
plt.show()
# 이제현 주: 종종 object oriented API가 pyplot보다 불편할 때가 있습니다.
#            contour plot의 colorbar는 무엇을 대상으로 할 지를 인자로 전달해야 합니다.

fig, ax = plt.subplots()    
contour = ax.contourf(X, Y, Z, cmap=matplotlib.cm.coolwarm)
plt.colorbar(contour)
plt.show()

산점도(Scatter plot)

단순히 각 점에 대한 x 및 y 좌표를 제공하면 산점도를 그릴 수 있습니다.

from numpy.random import rand
x, y = rand(2, 100)
plt.scatter(x, y)
plt.show()
from numpy.random import rand
x, y = rand(2, 100)

fig, ax = plt.subplots()
ax.scatter(x, y)
plt.show()

부수적으로 각 점의 크기를 정할 수도 있습니다.

x, y, scale = rand(3, 100)
scale = 500 * scale ** 5
plt.scatter(x, y, s=scale)
plt.show()
x, y, scale = rand(3, 100)
scale = 500 * scale ** 5

fig, ax = plt.subplots()
ax.scatter(x, y, s=scale)
plt.show()

마찬가지로 여러 속성을 설정할 수 있습니다. 가령 테두리 및 모양의 내부 색상, 그리고 투명도와 같은것의 설정이 가능합니다.

# 점이 300개 찍힘 n=100이고 루프를 3번돌리기 때문
for color in ['red', 'green', 'blue']:
    n = 100
    x, y = rand(2, n)
    scale = 500.0 * rand(n) ** 5
    plt.scatter(x, y, s=scale, c=color, alpha=0.3, edgecolors='blue')

plt.grid(True)

plt.show()
fig, ax = plt.subplots()

for color in ['red', 'green', 'blue']:
    n = 100
    x, y = rand(2, n)
    scale = 500.0 * rand(n) ** 5
    ax.scatter(x, y, s=scale, c=color, alpha=0.3, edgecolors='blue')

ax.grid(True)

plt.show()

지금까지 해온것 처럼 plot 함수를 사용하여 선을 그릴 수 있습니다. 하지만, 가끔은 그래프를 통과하는 무한한 선을 그리는 유틸리티 함수를 만들면 편리합니다 (기울기와 절편으로). 또한 hlinesvlines 함수를 사용하면, 아래와 같이 부분 수평 및 수직 선을 그릴 수도 있습니다:

from numpy.random import randn

def plot_line(axis, slope, intercept, **kargs):
    xmin, xmax = axis.get_xlim()
    plt.plot([xmin, xmax], [xmin*slope+intercept, xmax*slope+intercept], **kargs) # x값은[xmin, xmax], y값은[xmin*slope+intercept, xmax*slope+intercept]

x = randn(1000)
y = 0.5*x + 5 + randn(1000)*2
plt.axis([-2.5, 2.5, -5, 15])
plt.scatter(x, y, alpha=0.2)
plt.plot(1, 0, "ro")
plt.vlines(1, -5, 0, color="red") # 수직선
plt.hlines(0, -2.5, 1, color="red") # 수평선
plot_line(axis=plt.gca(), slope=0.5, intercept=5, color="magenta")
plt.grid(True)
plt.show()
from numpy.random import randn

# Axis를 인자로 전달하여 함수 연산과 시각화를 수행합니다.
def plot_line(axis, slope, intercept, **kargs):
    xmin, xmax = axis.get_xlim()
    axis.plot([xmin, xmax], [xmin*slope+intercept, xmax*slope+intercept], **kargs)

x = randn(1000)
y = 0.5*x + 5 + randn(1000)*2

fig, ax = plt.subplots()

ax.set_xlim(-2.5, 2.5)
ax.set_ylim(-5, 15)
ax.scatter(x, y, alpha=0.2)
ax.plot(1, 0, "ro")
ax.vlines(1, -5, 0, color="red")
ax.hlines(0, -2.5, 1, color="red")
plot_line(axis=ax, slope=0.5, intercept=5, color="magenta")  
ax.grid(True)
plt.show()

히스토그램

data = [1, 1.1, 1.8, 2, 2.1, 3.2, 3, 3, 3, 3]
plt.subplot(211)
plt.hist(data, bins = 10, rwidth=0.8)

plt.subplot(212)
plt.hist(data, bins = [1, 1.5, 2, 2.5, 3], rwidth=0.95)
plt.xlabel("Value")
plt.ylabel("Frequency")

plt.show()
data = [1, 1.1, 1.8, 2, 2.1, 3.2, 3, 3, 3, 3]

fig, ax = plt.subplots(2, 1)
ax[0].hist(data, bins = 10, rwidth=0.8)

ax[1].hist(data, bins = [1, 1.5, 2, 2.5, 3], rwidth=0.95)
ax[1].set_xlabel("Value")
ax[1].set_ylabel("Frequency")

plt.show()
data1 = np.random.randn(400)
data2 = np.random.randn(500) + 3
data3 = np.random.randn(450) + 6
data4a = np.random.randn(200) + 9
data4b = np.random.randn(100) + 10

plt.hist(data1, bins=5, color='g', alpha=0.75, label='bar hist') # default histtype='bar'
plt.hist(data2, color='b', alpha=0.65, histtype='stepfilled', label='stepfilled hist')
plt.hist(data3, color='r', histtype='step', label='step hist')
plt.hist((data4a, data4b), color=('r','m'), alpha=0.55, histtype='barstacked', label=('barstacked a', 'barstacked b'))

plt.xlabel("Value")
plt.ylabel("Frequency")
plt.legend()
plt.grid(True)
plt.show()
data1 = np.random.randn(400)
data2 = np.random.randn(500) + 3
data3 = np.random.randn(450) + 6
data4a = np.random.randn(200) + 9
data4b = np.random.randn(100) + 10

fig, ax = plt.subplots()
ax.hist(data1, bins=5, color='g', alpha=0.75, label='bar hist') # default histtype='bar'
ax.hist(data2, color='b', alpha=0.65, histtype='stepfilled', label='stepfilled hist')
ax.hist(data3, color='r', histtype='step', label='step hist')
ax.hist((data4a, data4b), color=('r','m'), alpha=0.55, histtype='barstacked', label=('barstacked a', 'barstacked b'))

ax.set_xlabel("Value")
ax.set_ylabel("Frequency")
ax.legend()
ax.grid(True)
plt.show()

이미지

matplotlib에서의 이미지 불러오기, 생성하기, 화면에 그리기는 꽤 간단합니다.

이미지를 불러오려면 matplotlib.image 모듈을 임포트하고, 파일이름을 지정한 imread 함수를 호출해 주면 됩니다. 그러면 이미지 데이터가 NumPy의 배열로서 반환됩니다. 앞서 저장했던 my_square_function.png 이미지에 대하여 이를 수행해 보겠습니다.

  • 이미지 단독 출력은 pyplotobject oriented API 사이에 별 차이가 없습니다.
  • Axes를 지정해서 출력하는 것이 다를 뿐입니다.
  • pyplot과의 중복성이 강하지만 익숙해지는 차원에서 object oriented API를 함께 도시합니다.
import matplotlib.image as mpimg

img = mpimg.imread('my_square_function.png')
print(img.shape, img.dtype)
(288, 432, 4) float32

288x432 크기의 이미지를 불러왔습니다. 각 픽셀은 0~1 사이의 32비트 부동소수 값인 4개의 요소(빨강, 초록, 파랑, 투명도)로 구성된 배열로 표현됩니다. 이번에는 imshow함수를 호출해 보겠습니다:

plt.imshow(img)
plt.show()
fig, ax = plt.subplots()
ax.imshow(img)
plt.show()

이미지 출력에 포함된 축을 숨기고 싶다면 아래와 같이 축을 off 시켜줄 수 있습니다:

plt.imshow(img)
plt.axis('off')
plt.show()
fig, ax = plt.subplots()
ax.imshow(img)
ax.axis('off')
plt.show()

여러분만의 이미지를 생성하는것도 마찬가지로 간단합니다:

img = np.arange(100*100).reshape(100, 100)
print(img)
plt.imshow(img)
plt.show()
[[   0    1    2 ...   97   98   99]
 [ 100  101  102 ...  197  198  199]
 [ 200  201  202 ...  297  298  299]
 ...
 [9700 9701 9702 ... 9797 9798 9799]
 [9800 9801 9802 ... 9897 9898 9899]
 [9900 9901 9902 ... 9997 9998 9999]]
img = np.arange(100*100).reshape(100, 100)
print(img)

fig, ax = plt.subplots()
ax.imshow(img)
plt.show()
[[   0    1    2 ...   97   98   99]
 [ 100  101  102 ...  197  198  199]
 [ 200  201  202 ...  297  298  299]
 ...
 [9700 9701 9702 ... 9797 9798 9799]
 [9800 9801 9802 ... 9897 9898 9899]
 [9900 9901 9902 ... 9997 9998 9999]]

RGB 수준을 제공하지 않는다면, imshow 함수는 자동으로 값을 색그래디언트에 매핑합니다. 기본적인 동작에서의 색그래디언트는 파랑(낮은 값) 에서 빨강(높은 값)으로 움직입니다. 하지만 아래와 같이 다른 색상맵을 선택할 수도 있습니다:

plt.imshow(img, cmap="hot")
plt.show()
fig, ax = plt.subplots()
ax.imshow(img, cmap="hot")
plt.show()

RGB 이미지를 직접적으로 생성하는것 또한 가능합니다:

img = np.empty((20,30,3))
img[:, :10] = [0, 0, 0.6]
img[:, 10:20] = [1, 1, 1]
img[:, 20:] = [0.6, 0, 0]
plt.imshow(img, interpolation='bilinear')
plt.show()
img = np.empty((20,30,3))
img[:, :10] = [0, 0, 0.6]
img[:, 10:20] = [1, 1, 1]
img[:, 20:] = [0.6, 0, 0]

fig, ax = plt.subplots()
ax.imshow(img, interpolation='bilinear')
plt.show()

img 배열이 매우 작기 때문에 (20x30), imshow 함수는 이미지를 figure 크기에 맞도록 늘려버린채 출력합니다. 이러한 늘리기의 기본 동작은 쌍선형 보간법(bilinear interpolation)을 사용하여 추가된 픽셀을 매꿉니다. 테두리가 흐릿한 이유입니다.

다른 보간법 알고리즘을 선택할 수도 있습니다. 가령 아래와 같이 근접 픽셀을 복사하는 방법이 있습니다:

  • 위 코드의 ax.imshow(img, interpolation='bilinear') 부분은 원문에서 ax.imshow(img)로 되어 있습니다.
  • matplotlib 2.0 이전에는 interpolation='bilinear'가 기본값이기 때문에 경계선이 흐려지는 문제가 있었습니다.
  • 그러나 이후 interpolation='nearest'로 기본값이 변경되어 흐려지는 문제가 더 이상 발생하지 않습니다.
  • 자세한 사항은 이 글을 참고하십시오.
plt.imshow(img, interpolation="nearest")
plt.show()
fig, ax = plt.subplots()
ax.imshow(img, interpolation="nearest")
plt.show()

애니메이션

matplotlib은 이미지 생성에 주로 사용되지만, 애니메이션의 출력도 가능합니다. 우선 matplotlib.animation을 임포트 해 줘야 합니다. 그 다음은 (주피터 노트북에서) nbagg를 백엔드로 설정하거나, 아래의 코드를 실행해 주면 됩니다.

import matplotlib.animation as animation
matplotlib.rc('animation', html='jshtml')

다음의 예는 데이터를 생성하는것으로 시작됩니다. 그 다음, 빈 그래프를 생성하고, 애니메이션을 그릴 매 프레임 마다 호출될 갱신(update) 함수를 정의합니다. 마지막으로, FuncAnimation 인스턴스를 생성하여 그래프에 애니메이션을 추가합니다.

FuncAnimation 생성자는 figure, 갱신 함수, 그 외의 파라미터를 수용합니다. 각 프레임간 20ms의 시간차가 있는 100개의 프레임으로 구성된 애니메이션에 대한 인스턴스를 만들었습니다. 애니메이션의 각 프레임마다 FuncAnimation 는 갱신 함수를 호출하고, 프레임 번호를 num (이 예에서는 0~99의 범위) 으로서 전달해 줍니다. 또한 갱신 함수의 추가적인 두 파라미터는 FuncAnimation 생성시 fargs에 넣어준 값이 됩니다.

작성한 갱신 함수는 선을 구성하는 데이터를 0 ~ num 데이터로 설정합니다 (따라서 데이터가 점진적으로 그려집니다). 그리고 약간의 재미 요소를 위해서, 각 데이터에 약간의 무작위 수를 추가하여 선이 씰룩씰룩 움직이게끔 해 주었습니다.

x = np.linspace(-1, 1, 100)
y = np.sin(x**2*25)
data = np.array([x, y])

fig = plt.figure()
line, = plt.plot([], [], "r-") # start with an empty plot
plt.axis([-1.1, 1.1, -1.1, 1.1])
plt.plot([-0.5, 0.5], [0, 0], "b-", [0, 0], [-0.5, 0.5], "b-", 0, 0, "ro")
plt.grid(True)
plt.title("Marvelous animation")

# this function will be called at every iteration
def update_line(num, data, line):
    line.set_data(data[..., :num] + np.random.rand(2, num) / 25)  # we only plot the first `num` data points.
    return line,

line_ani = animation.FuncAnimation(fig, update_line, frames=100, fargs=(data, line), interval=67)
plt.close()
line_ani
</input>
x = np.linspace(-1, 1, 100)
y = np.sin(x**2*25)
data = np.array([x, y])

fig, ax = plt.subplots()

line, = ax.plot([], [], "r-") # start with an empty plot
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.plot([-0.5, 0.5], [0, 0], "b-", [0, 0], [-0.5, 0.5], "b-", 0, 0, "ro")
ax.grid(True)
ax.set_title("Marvelous animation")

# this function will be called at every iteration
def update_line(num, data, line):
    line.set_data(data[..., :num] + np.random.rand(2, num) / 25)  # we only plot the first `num` data points.
    return line,

line_ani = animation.FuncAnimation(fig, update_line, frames=100, fargs=(data, line), interval=67)
plt.close()
line_ani
</input>

seaborn

import seaborn as sns
sns.set()
sns.set(style="darkgrid")


import numpy as np
import pandas as pd

# importing matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")
plt.rcParams['figure.figsize']=(10,10)
data_BM = pd.read_csv('bigmart_data.csv')
# drop the null values
data_BM = data_BM.dropna(how="any")
# multiply Item_Visibility by 100 to increase size
data_BM["Visibility_Scaled"] = data_BM["Item_Visibility"] * 100
# view the top results
data_BM.head()
Item_Identifier Item_Weight Item_Fat_Content Item_Visibility Item_Type Item_MRP Outlet_Identifier Outlet_Establishment_Year Outlet_Size Outlet_Location_Type Outlet_Type Item_Outlet_Sales Visibility_Scaled
0 FDA15 9.300 Low Fat 0.016047 Dairy 249.8092 OUT049 1999 Medium Tier 1 Supermarket Type1 3735.1380 1.604730
1 DRC01 5.920 Regular 0.019278 Soft Drinks 48.2692 OUT018 2009 Medium Tier 3 Supermarket Type2 443.4228 1.927822
2 FDN15 17.500 Low Fat 0.016760 Meat 141.6180 OUT049 1999 Medium Tier 1 Supermarket Type1 2097.2700 1.676007
4 NCD19 8.930 Low Fat 0.000000 Household 53.8614 OUT013 1987 High Tier 3 Supermarket Type1 994.7052 0.000000
5 FDP36 10.395 Regular 0.000000 Baking Goods 51.4008 OUT018 2009 Medium Tier 3 Supermarket Type2 556.6088 0.000000

1. 기본 플롯 만들기

matplotlib에서 여러 줄이 필요한 한 줄로 seaborn에서 몇 가지 기본 플롯을 만드는 방법을 살펴보겠습니다.

라인 차트

  • 일부 데이터 세트의 경우 한 변수의 변화를 시간의 함수로 이해하거나 이와 유사한 연속 변수를 이해하고 싶을 수 있습니다.
  • seaborn에서 이는 lineplot() 함수로 직접 또는 kind="line"을 설정하여 relplot()으로 수행할 수 있습니다.
sns.lineplot(x="Item_Weight", y="Item_MRP",data=data_BM[:50]);

막대 차트

  • Seaborn에서는 barplot 기능을 이용하여 간단하게 막대그래프를 생성할 수 있습니다.
  • matplotlib에서 동일한 결과를 얻으려면 데이터 범주를 현명하게 그룹화하기 위해 추가 코드를 작성해야 했습니다.
  • 그리고 나서 플롯이 올바르게 나오도록 훨씬 더 많은 코드를 작성해야 했습니다.
sns.barplot(x="Item_Type", y="Item_MRP", data=data_BM[:5])
<AxesSubplot:xlabel='Item_Type', ylabel='Item_MRP'>

히스토그램

  • distplot()을 사용하여 Seaborn에서 히스토그램을 생성할 수 있습니다. 사용할 수 있는 여러 옵션이 있으며 노트북에서 더 자세히 살펴보겠습니다.
sns.distplot(data_BM['Item_MRP'])
<AxesSubplot:xlabel='Item_MRP', ylabel='Density'>

상자 플롯

  • Seaborn에서 boxplot을 생성하기 위해 boxplot()을 사용할 수 있습니다.
  • 아이템의 Item_Outlet_Sales 분포를 시각화해 봅시다.
sns.boxplot(data_BM['Item_Outlet_Sales'], orient='vertical')
# 2000에 몰려있음, 오른쪽으로 긴 꼬리를 가지는 분포
<AxesSubplot:xlabel='Item_Outlet_Sales'>

바이올린 플롯

  • 바이올린 플롯은 상자 및 whisker 플롯과 유사한 역할을 합니다.
  • 이러한 분포를 비교할 수 있도록 하나(또는 그 이상) 범주형 변수의 여러 수준에 걸친 정량적 데이터의 분포를 보여줍니다.
  • 모든 플롯 구성 요소가 실제 데이터 포인트에 해당하는 상자 플롯과 달리 바이올린 플롯은 기본 분포의 커널 밀도 추정을 특징으로 합니다.
  • Seaborn에서 violinplot()을 사용하여 바이올린 플롯을 만들 수 있습니다.
sns.violinplot(data_BM['Item_Outlet_Sales'], orient='vertical', color='magenta')
<AxesSubplot:xlabel='Item_Outlet_Sales'>

scatter plot(산점도)

  • 각 포인트는 데이터 세트의 관찰을 나타내는 포인트 클라우드를 사용하여 두 변수의 분포를 나타냅니다.
  • 이 묘사를 통해 눈은 그들 사이에 의미 있는 관계가 있는지 여부에 대한 상당한 양의 정보를 추론할 수 있습니다.
  • relplot()kind=scatter 옵션과 함께 사용하여 seaborn에서 산점도를 그릴 수 있습니다.

참고: 여기에서는 플롯에 대한 데이터의 하위 집합만 사용할 것입니다.

sns.relplot(x="Item_MRP", y="Item_Outlet_Sales", data=data_BM[:200], kind="scatter");

Hue semantic(색조 의미)

세 번째 변수에 따라 점을 색칠하여 플롯에 다른 차원을 추가할 수도 있습니다. Seaborn에서는 이것을 "색조 의미론" 사용이라고 합니다.

sns.relplot(x="Item_MRP", y="Item_Outlet_Sales", hue="Item_Type",data=data_BM[:200]);

hue(색조) 의미 체계를 사용하면 seaborn에서 더 복잡한 선 플롯을 만들 수 있습니다.

  • 다음 예에서는 Outlet_Size의 다른 범주에 대해 다른 선 플롯이 만들어집니다.
sns.lineplot(x="Item_Weight", y="Item_MRP",hue='Outlet_Size',data=data_BM[:150]);

Bubble plot(버블 플롯)

  • hue(색조) 시맨틱을 활용하여 Item_Visibility별로 거품을 색칠함과 동시에 개별 거품의 크기로 사용합니다.
sns.relplot(x="Item_MRP", y="Item_Outlet_Sales", data=data_BM[:200], kind="scatter", size="Visibility_Scaled", hue="Visibility_Scaled");

카테고리별 하위 플롯

  • Seaborn에서 카테고리별 플롯을 생성할 수도 있습니다.
  • 각 Outlet_Size에 대한 산점도를 만들었습니다
sns.relplot(x="Item_Weight", y="Item_Visibility",hue='Outlet_Size',style='Outlet_Size',col='Outlet_Size',data=data_BM[:100]);
# hue로 컬러 설정
# ggplot에서도 비슷한 함수 있음

2. seaborn의 고급 범주형 플롯

범주형 변수의 경우 seaborn에 세 가지 다른 패밀리가 있습니다.

  • 범주형 산점도:

    • stripplot() (kind="strip" 사용, 기본값)
    • swarmplot() (종류="swarm" 포함)
  • 범주 분포도:

    • boxplot() (종류="box" 포함)
    • 바이올린플롯() (종류="바이올린" 포함)
    • boxenplot() (종류="boxen" 사용)
  • 범주 추정 플롯:

    • pointplot() (종류="포인트" 사용)
    • barplot() (종류="막대" 사용)

catplot()에서 데이터의 기본 표현은 산점도를 사용합니다.

a. 범주형 산점도

스트립 플롯

  • 하나의 변수가 범주형인 산점도를 그립니다.
  • catplot()에서 kind=strip을 전달하여 생성할 수 있습니다.
sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales", kind='strip',data=data_BM[:250]);

스웜 플롯

  • 이 함수는 stripplot()과 유사하지만 포인트가 겹치지 않도록 조정됩니다(범주 축을 따라만).
  • 이것은 값의 분포를 더 잘 표현하지만 많은 수의 관찰에 대해서는 잘 확장되지 않습니다. 이러한 스타일의 플롯은 때때로 "꿀벌"이라고 불립니다.
  • catplot()에서 kind=swarm을 전달하여 생성할 수 있습니다.
sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales", kind='swarm',data=data_BM[:250]);

b. 범주형 분포도

상자 그림

  • 상자 그림은 극단값과 함께 분포의 3사분위수 값을 보여줍니다.
  • "수염"은 하위 및 상위 사분위수의 1.5 IQR 내에 있는 점으로 확장되며, 이 범위를 벗어나는 관찰은 독립적으로 표시됩니다.
  • 이것은 상자 그림의 각 값이 데이터의 실제 관측값에 해당함을 의미합니다.
sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales",kind="box",data=data_BM);

Violin Plots

sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales",kind="violin",data=data_BM);

boxplot

  • 이 스타일의 플롯은 "문자 값"으로 정의되는 많은 수의 분위수를 보여주기 때문에 원래 "문자 값" 플롯이라고 명명되었습니다.
  • 모든 특징이 실제 관찰과 일치하는 분포의 비모수적 표현을 플로팅하는 상자 플롯과 유사합니다.
  • 더 많은 분위수를 표시함으로써 특히 꼬리 부분의 분포 모양에 대한 더 많은 정보를 제공합니다.
sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales",kind="boxen",data=data_BM);

Point plot

sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales",kind="point",data=data_BM);

Bar plots

sns.catplot(x="Outlet_Size", y="Item_Outlet_Sales",kind="bar",data=data_BM);

3. Density Plots

히스토그램 대신 Seaborn이 sn.kdeplot으로 수행하는 커널 밀도 추정을 사용하여 분포를 원활하게 추정할 수 있습니다.

plt.figure(figsize=(10,10))
sns.kdeplot(data_BM['Item_Visibility'], shade=True);
plt.figure(figsize=(10,10))
sns.kdeplot(data_BM['Item_MRP'], shade=True);

히스토그램 및 Density Plots

히스토그램과 KDE는 distplot을 사용하여 결합할 수 있습니다.

plt.figure(figsize=(10,10))
sns.distplot(data_BM['Item_Outlet_Sales']);

4. Pair plots

  • 조인트 플롯을 더 큰 차원의 데이터세트로 일반화하면 페어 플롯으로 끝납니다. 이것은 모든 값 쌍을 서로에 대해 플롯하려는 경우 다차원 데이터 간의 상관 관계를 탐색하는 데 매우 유용합니다.

  • 세 가지 붓꽃 종의 꽃잎과 꽃받침 측정값을 나열하는 잘 알려진 Iris 데이터 세트를 사용하여 이것을 시연할 것입니다.

iris = sns.load_dataset("iris")
iris.head()
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
sns.pairplot(iris, hue='species', height=2.5); # hue는 color

seaborn with matplotlib

1.1 Load data

  • 예제로 사용할 펭귄 데이터를 불러옵니다
  • seaborn에 내장되어 있습니다.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

penguins = sns.load_dataset("penguins")
penguins.head()
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
0 Adelie Torgersen 39.1 18.7 181.0 3750.0 Male
1 Adelie Torgersen 39.5 17.4 186.0 3800.0 Female
2 Adelie Torgersen 40.3 18.0 195.0 3250.0 Female
3 Adelie Torgersen NaN NaN NaN NaN NaN
4 Adelie Torgersen 36.7 19.3 193.0 3450.0 Female

1.2 Figure and Axes

  • matplotlib으로 도화지figure를 깔고 축공간axes를 만듭니다.
  • 1 x 2 축공간을 구성합니다.
fig, axes = plt.subplots(ncols=2, figsize=(8,4))

fig.tight_layout()

1.3 plot with matplotlib

  • matplotlib 기능을 이용해서 산점도를 그립니다.
  • x축은 부리 길이 bill length
  • y축은 부리 위 아래 두께 bill depth
  • 색상은 종species로 합니다.
  • Adelie, Chinstrap, Gentoo이 있습니다.
  • 두 축공간 중 왼쪽에만 그립니다.
  • 컬러를 다르게 주기 위해 f-string 포맷을 사용했습니다.
  • f-string 포맷에 대한 설명은 https://blockdmask.tistory.com/429를 참고하세요
fig, axes = plt.subplots(ncols=2,figsize=(8,4))

species_u = penguins["species"].unique()

for i, s in enumerate(species_u):
    axes[0].scatter(penguins["bill_length_mm"].loc[penguins["species"]==s],
                    penguins["bill_depth_mm"].loc[penguins["species"]==s],
                    c=f"C{i}", label=s, alpha=0.3)
    
axes[0].legend(species_u, title="species")
axes[0].set_xlabel("Bill Length (mm)")
axes[0].set_ylabel("Bill Depth (mm)")

# plt.show()
fig.tight_layout()

조금 더 간단히 그리는 방법 matplotlib는 기본적으로 Categorical 변수를 color로 바로 사용하지 못함

penguins["species_codes"] = pd.Categorical(penguins["species"]).codes

fig, axes = plt.subplots(ncols=2,figsize=(8,4))

axes[0].scatter(data=penguins, x="bill_length_mm", y="bill_depth_mm", c="species_codes", alpha=0.3)
<matplotlib.collections.PathCollection at 0x1ee029ed820>

1.4 Plot with seaborn

fig, axes = plt.subplots(ncols=2,figsize=(8,4))

species_u = penguins["species"].unique()

# plot 0 : matplotlib

for i, s in enumerate(species_u):
    axes[0].scatter(penguins["bill_length_mm"].loc[penguins["species"]==s],
                    penguins["bill_depth_mm"].loc[penguins["species"]==s],
                    c=f"C{i}", label=s, alpha=0.3)
    
axes[0].legend(species_u, title="species")
axes[0].set_xlabel("Bill Length (mm)")
axes[0].set_ylabel("Bill Depth (mm)")


# plot 1 : seaborn
sns.scatterplot(x="bill_length_mm", y="bill_depth_mm", hue="species", data=penguins, alpha=0.3, ax=axes[1])
axes[1].set_xlabel("Bill Length (mm)")
axes[1].set_ylabel("Bill Depth (mm)")

fig.tight_layout()
  • 단 세 줄로 거의 동일한 그림이 나왔습니다.
  • scatter plot의 점 크기만 살짝 작습니다.
  • label의 투명도만 살짝 다릅니다.
  • seaborn 명령 scatterplot()을 그대로 사용했습니다.
  • x축과 y축 label도 바꾸었습니다.
  • ax=axes[1] 인자에서 볼 수 있듯, 존재하는 axes에 그림만 얹었습니다.
  • matplotlib 틀 + seaborn 그림 이므로, matplotlib 명령이 모두 통합니다.

1.5 matplotlib + seaborn & seaborn + matplotlib

  • matplotlib과 seaborn이 자유롭게 섞일 수 있습니다.
  • matplotlib 산점도 위에 seaborn 추세선을 얹을 수 있고,
  • seaborn 산점도 위에 matplotlib 중심점을 얹을 수 있습니다.
  • 파이썬 코드는 다음과 같습니다.
fig, axes = plt.subplots(ncols=2, figsize=(8, 4))

species_u = penguins["species"].unique()

# plot 0 : matplotlib + seaborn
for i, s in enumerate(species_u):
    # matplotlib 산점도
    axes[0].scatter(penguins["bill_length_mm"].loc[penguins["species"]==s],
                   penguins["bill_depth_mm"].loc[penguins["species"]==s],
                   c=f"C{i}", label=s, alpha=0.3
                  )
                  
    # seaborn 추세선
    sns.regplot(x="bill_length_mm", y="bill_depth_mm", data=penguins.loc[penguins["species"]==s], 
                scatter=False, ax=axes[0])
    
axes[0].legend(species_u, title="species")
axes[0].set_xlabel("Bill Length (mm)")
axes[0].set_ylabel("Bill Depth (mm)")

# plot 1 : seaborn + matplotlib
# seaborn 산점도
sns.scatterplot(x="bill_length_mm", y="bill_depth_mm", hue="species", data=penguins, alpha=0.3, ax=axes[1])
axes[1].set_xlabel("Bill Length (mm)")
axes[1].set_ylabel("Bill Depth (mm)")

for i, s in enumerate(species_u):
    # matplotlib 중심점
    axes[1].scatter(penguins["bill_length_mm"].loc[penguins["species"]==s].mean(),
                   penguins["bill_depth_mm"].loc[penguins["species"]==s].mean(),
                   c=f"C{i}", alpha=1, marker="x", s=100
                  )

fig.tight_layout()

1.6 seaborn + seaborn + matplotlib

  • 안 될 이유가 없습니다.
  • seaborn scatterplot + seaborn kdeplot + matplotlib text입니다
fig, ax = plt.subplots(figsize=(6,5))

# plot 0: scatter plot
sns.scatterplot(x="bill_length_mm", y="bill_depth_mm", color="k", data=penguins, alpha=0.3, ax=ax, legend=False)

# plot 1: kde plot
sns.kdeplot(x="bill_length_mm", y="bill_depth_mm", hue="species", data=penguins, alpha=0.5, ax=ax, legend=False)

# text:
species_u = penguins["species"].unique()
for i, s in enumerate(species_u):
    ax.text(penguins["bill_length_mm"].loc[penguins["species"]==s].mean(),
            penguins["bill_depth_mm"].loc[penguins["species"]==s].mean(),
            s = s, fontdict={"fontsize":14, "fontweight":"bold","color":"k"}
            )

ax.set_xlabel("Bill Length (mm)")
ax.set_ylabel("Bill Depth (mm)")

fig.tight_layout()